#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ERNIE pretraining."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import multiprocessing

import paddle
import numpy as np
import paddle.fluid as fluid

from reader.pretraining import ErnieDataReader
from model.ernie_v1 import ErnieModel, ErnieConfig
from optimization import optimization
from utils.args import print_arguments, check_cuda
from utils.init import init_checkpoint, init_pretraining_params

from pretrain_args import parser
# Build static-graph (fluid) programs instead of running eagerly.
paddle.enable_static()

# Parsed command-line options; read directly by create_model() and passed to
# train() from the __main__ entry point.
args = parser.parse_args()

# Sequences per batch: batch_size divided by the per-sequence length.
# NOTE(review): this implies batch_size is a token budget — confirm against
# pretrain_args.
num_seq = args.batch_size // args.max_seq_len

# Hard-coded device selection used by train(): a truthy ``ascend`` runs on
# NPUPlace(device); otherwise CUDAPlace(device) or CPUPlace per --use_cuda.
ascend = 1
device = 0
def create_model(pyreader_name, ernie_config):
    """Declare the pretraining feed variables, their data loader and the ERNIE graph.

    train() calls this inside ``fluid.program_guard``, so every variable and
    op declared here is added to that guard's program.

    Args:
        pyreader_name: Unused; kept for call-site compatibility.
        ernie_config: ``ErnieConfig`` forwarded to ``ErnieModel``.

    Returns:
        Tuple ``(pyreader, next_sent_acc, mask_lm_loss, total_loss)``: the
        non-iterable ``DataLoader`` feeding the inputs, followed by the three
        outputs of ``ErnieModel.get_pretraining_output``.
    """
    seq_len = args.max_seq_len
    # (name, shape, dtype) for each feed slot, in feed_list order; presumably
    # the data reader yields its fields in this same order — verify in
    # reader.pretraining.
    input_specs = [
        ('src_ids', [-1, seq_len], 'int32'),
        ('pos_ids', [-1, seq_len], 'int32'),
        ('sent_ids', [-1, seq_len], 'int32'),
        ('input_mask', [-1, seq_len, seq_len], 'float32'),
        ('mask_label', [-1, 1], 'int32'),
        ('mask_pos', [-1, 1], 'int32'),
        ('label', [-1, 1], 'int32'),
    ]

    inputs = [
        fluid.layers.data(name=name, shape=shape, dtype=dtype)
        for name, shape, dtype in input_specs
    ]
    # Turn off data-loader ordering (see fluid.reader.keep_data_loader_order).
    fluid.reader.keep_data_loader_order(False)
    pyreader = fluid.io.DataLoader.from_generator(
        feed_list=inputs,
        capacity=70,
        iterable=False)

    src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos, labels = inputs

    ernie = ErnieModel(
        src_ids=src_ids,
        position_ids=pos_ids,
        sentence_ids=sent_ids,
        input_mask=input_mask,
        config=ernie_config,
        weight_sharing=args.weight_sharing,
        use_fp16=args.use_fp16)

    next_sent_acc, mask_lm_loss, total_loss = ernie.get_pretraining_output(
        mask_label, mask_pos, labels)

    return pyreader, next_sent_acc, mask_lm_loss, total_loss

def train(args):
    """Build the ERNIE pretraining program and run the training loop.

    Args:
        args: Parsed command-line namespace from ``pretrain_args.parser``.

    The loop runs until ``args.num_train_steps`` steps have executed or the
    data loader raises ``fluid.core.EOFException``. Every ``args.skip_steps``
    steps it fetches next-sentence accuracy, masked-LM loss, total loss and
    the learning rate, then prints a progress line. Every ``args.save_steps``
    steps it saves persistable variables to ``<args.checkpoints>/step_<steps>``.
    """
    print("pretraining start")
    ernie_config = ErnieConfig(args.ernie_config_path)
    ernie_config.print_config()

    train_program = fluid.Program()
    startup_prog = fluid.Program()
    with fluid.program_guard(train_program, startup_prog):
        with fluid.unique_name.guard():
            train_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model(
                pyreader_name='train_reader', ernie_config=ernie_config)
            scheduled_lr = optimization(
                loss=total_loss,
                warmup_steps=args.warmup_steps,
                num_train_steps=args.num_train_steps,
                learning_rate=args.learning_rate,
                train_program=train_program,
                startup_prog=startup_prog,
                weight_decay=args.weight_decay,
                scheduler=args.lr_scheduler,
                use_fp16=args.use_fp16)

    # The module-level ``ascend`` / ``device`` settings choose the device;
    # --use_cuda only matters when ``ascend`` is falsy.
    if ascend:
        place = fluid.NPUPlace(device)
    elif args.use_cuda:
        place = fluid.CUDAPlace(device)
    else:
        place = fluid.CPUPlace()

    # Single-trainer run: the global step advances by this many per iteration.
    nccl2_num_trainers = 1
    print("args.is_distributed:", args.is_distributed)

    exe = fluid.Executor(place)
    exe.run(startup_prog)

    if args.init_checkpoint:
        init_checkpoint(exe, args.init_checkpoint, train_program, args.use_fp16)

    data_reader = ErnieDataReader(
        filelist=args.train_filelist,
        batch_size=args.batch_size,
        vocab_path=args.vocab_path,
        voc_size=ernie_config['vocab_size'],
        epoch=args.epoch,
        max_seq_len=args.max_seq_len,
        generate_neg_sample=args.generate_neg_sample,
        use_fake=args.use_fake,
        seq_num=num_seq,
        ascend=ascend)

    train_pyreader.set_batch_generator(data_reader.data_generator())
    train_pyreader.start()

    # Logging cadence in global steps. With skip_steps == 0 the modulo below
    # would raise ZeroDivisionError, so this logs at least every step.
    skip_steps = max(args.skip_steps, 1) * nccl2_num_trainers

    steps = 0
    cost = []
    lm_cost = []
    acc = []
    time_begin = time.time()
    while steps < args.num_train_steps:
        try:
            steps += nccl2_num_trainers

            # Every run must pass ``program=train_program``. Without it,
            # Executor.run() executes fluid.default_main_program(), which is
            # not the training program once the program_guard above has
            # exited, so no training step (and no batch) would be processed.
            if steps % skip_steps != 0:
                exe.run(program=train_program, fetch_list=[])
            else:
                each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = exe.run(
                    program=train_program,
                    fetch_list=[
                        next_sent_acc, mask_lm_loss, total_loss, scheduled_lr
                    ])
                acc.extend(each_next_acc)
                lm_cost.extend(each_mask_lm_cost)
                cost.extend(each_total_cost)

                print("feed_queue size", train_pyreader.queue.size())
                used_time = time.time() - time_begin
                epoch, current_file_index, total_file, current_file, mask_type = \
                    data_reader.get_progress()
                print("current learning_rate:%f" % np_lr[0])
                print(
                    "epoch: %d, progress: %d/%d, step: %d, loss: %f, "
                    "ppl: %f, next_sent_acc: %f, speed: %f steps/s, file: %s, mask_type: %s"
                    % (epoch, current_file_index, total_file, steps,
                       np.mean(np.array(cost)),
                       np.mean(np.exp(np.array(lm_cost))),
                       np.mean(np.array(acc)), skip_steps / used_time,
                       current_file, mask_type))
                cost = []
                lm_cost = []
                acc = []
                time_begin = time.time()

            # save_steps == 0 disables periodic checkpoints instead of crashing.
            if args.save_steps and steps % args.save_steps == 0:
                save_path = os.path.join(args.checkpoints, "step_" + str(steps))
                fluid.io.save_persistables(exe, save_path, train_program)

        except fluid.core.EOFException:
            # Data exhausted: reset the non-iterable loader and stop training.
            train_pyreader.reset()
            break


if __name__ == "__main__":
    # Script entry point: echo the parsed options, run check_cuda on the
    # --use_cuda flag (presumably validating it against the Paddle build),
    # then start pretraining with the module-level ``args``.
    cli_args = args
    print_arguments(cli_args)
    check_cuda(cli_args.use_cuda)
    train(cli_args)
